{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(70000, 784)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load MNIST: 70,000 handwritten digits, each flattened to 784 pixel values.\n",
    "# `fetch_mldata` was removed in scikit-learn 0.22 (mldata.org is offline), so the\n",
    "# dataset is fetched from OpenML and cached under `data_home` (needs scikit-learn\n",
    "# >= 0.22 for `as_frame`). `as_frame=False` returns NumPy arrays, which the row\n",
    "# indexing below relies on. Pixels are returned as floats and are cast to uint8\n",
    "# (MNIST intensities are 0-255) to keep the augmented training set small in memory;\n",
    "# labels are returned as strings and are cast to integers.\n",
    "from sklearn.datasets import fetch_openml\n",
    "mnist = fetch_openml('mnist_784', version=1, data_home='./', as_frame=False)\n",
    "X, y = mnist['data'].astype('uint8'), mnist['target'].astype('uint8')\n",
    "X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "建立训练集和测试集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# The first 60,000 samples are the training set, the remaining 10,000 the test set.\n",
    "N_TRAIN = 60000\n",
    "RANDOM_SEED = 42\n",
    "X_train, X_test, y_train, y_test = X[:N_TRAIN], X[N_TRAIN:], y[:N_TRAIN], y[N_TRAIN:]\n",
    "\n",
    "# Shuffle only the training set. Seeding NumPy's global RNG first makes this shuffle\n",
    "# (and later unseeded np.random calls) reproducible under Restart & Run All.\n",
    "np.random.seed(RANDOM_SEED)\n",
    "shuffle_index = np.random.permutation(len(X_train))\n",
    "X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "# Optional hyperparameter search for the KNN classifier: 4 parameter combinations x\n",
    "# 2 folds over the 60,000 training samples. It is expensive, so it is opt-in; set\n",
    "# RUN_GRID_SEARCH = True to re-run it.\n",
    "RUN_GRID_SEARCH = False\n",
    "param_grid = [{'weights': [\"uniform\", \"distance\"], 'n_neighbors': [3, 4]}]\n",
    "if RUN_GRID_SEARCH:\n",
    "    knn_clf = KNeighborsClassifier()\n",
    "    grid_search = GridSearchCV(knn_clf, param_grid, cv=2, verbose=3, n_jobs=-1)\n",
    "    grid_search.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best hyperparameter combination found by the grid search.\n",
    "# Disabled: uncomment once `grid_search` has been created and fitted.\n",
    "# grid_search.best_params_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean cross-validated score of the best hyperparameter combination.\n",
    "# Disabled: uncomment once `grid_search` has been created and fitted.\n",
    "# grid_search.best_score_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 移动像素"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `scipy.ndimage.interpolation` is a deprecated namespace (SciPy 1.8+); `shift` is\n",
    "# imported from its public location, `scipy.ndimage`.\n",
    "from scipy.ndimage import shift\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "def shift_image(image, dx, dy, shape=(28, 28)):\n",
    "    \"\"\"Translate a flattened image by (dx, dy) pixels and return it flattened.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    image : numpy.ndarray\n",
    "        Flattened image with shape[0] * shape[1] elements (one MNIST row by default).\n",
    "    dx : float\n",
    "        Horizontal shift in pixels; positive values move the content right.\n",
    "    dy : float\n",
    "        Vertical shift in pixels; positive values move the content down.\n",
    "    shape : tuple of int, default (28, 28)\n",
    "        (rows, columns) of the 2-D image that `image` is a flattening of.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        1-D shifted image with the same dtype as `image`. Pixels that move in\n",
    "        from beyond the border are filled with 0; non-integer shifts are\n",
    "        spline-interpolated (SciPy's default order).\n",
    "    \"\"\"\n",
    "    grid = image.reshape(shape)\n",
    "    # ndimage.shift expects one offset per axis in (row, column) order, i.e. (dy, dx).\n",
    "    shifted = shift(grid, [dy, dx], cval=0, mode=\"constant\")\n",
    "    return shifted.reshape([-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAp0AAADTCAYAAADDGKgLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAd+0lEQVR4nO3df5RV5X3v8c8XBEPwxyVlIIo/yLJwC22N6AjpxaRxJbSQUJWGxrBuLdqsGLjg0txYyrI/AnjVJCXSuBKCcLWDGkB6xcablZqCBrjWFh2DCsnohRCMoJGZZdYFSgISvvePvdHjzLOZc+bsZ2b2Pu/XWrNmzufsc/ZzYL7zfM8+5znb3F0AAABATAP6egAAAAAoP5pOAAAAREfTCQAAgOhoOgEAABAdTScAAACio+kEAABAdDSd/YiZ7TWzW2u8jZvZzJzHscjMduZ5n0BvqaaOOm9jZu83s38xs/8ws9w/R87MdprZohpvc6uZ7c17LEBRFLGWzex6Mztc432+18z+l5n9v3ROH13nMPstms6cmdkoM1tpZvvM7JiZ7TezVWZ2XhU3v1zS8hp3eY6k/137SIHiMbMmM1ueTjRHzewNM3vCzKbUeFeda+1WSedKukRJTUV5QgcgQS2/y59L+oikK5SM+dWeHIQqgtP6egBlYmYfkPS0pJ9Kmi1pl6SLJN0h6Vkz+z133xu43WB3P+bu7bXu091/Xt+ogUJ5RNJ7JX1W0m5JIyT9vqTfqOVOArX2m5Kec/ddeQwSQLeo5Xf8pqQ2d99xMjCzPhxOPBzpzNc3JZ2Q9HF3f8Ldf+buP5D08TT/piSZ2WYz+5aZLTWzdkn/muadXyYYa2ZbzOxXZvaymX3CzA6b2fUV27z9DM7MRqeXP2VmG83siJn9uPKZo5kNNLP7zOynZvZLM9tlZgvMjN8F9Gtm9p8kfVjSwrS+XnH3Z919qbuv67T5e8zsXjM7mL7q8Bed7uvtWktfwr5a0p+l9dNS8bL2P6bZ3orb/pGZPZfW5U/N7A4zG1xx/Qgz+05aX6+Y2Z9X+fgWmNnP0xp/QNIZna4fYGZ/Y2avpkeGdpjZ1RXXP2xm36q4fEc69kkV2T4z+6/pzy1m9l0zuzl9ReYXZvYPZvbeasYL9FTZaznweDP3Y2abJd0s6SPp+Dan2YWS/i7NSnPqSBqNnJjZ+yRNlfRNdz9SeV16ebmkaWY2LI3/VJIpKbw/C9zfAEmPSjou6UOSrpf0JUmnVzGcOyTdI+mDkp6VtM7MTk5gAyTtl/RpSeMk/ZWk2yTdUOVDBfrK4fTrKjN7TzfbfkHSDkmXSvqKpK+a2e9lbHu5pE2S1it5aevmNJOkz6XZ5ZJkZn8o6duSviHpt5W8LDZT0p0V99ei5MjFxyVdo6S+R59qsGb2aUn/Q0mNXyrpZUn/vdNmN0v6C0l/Kel3lfx92GBml6TXb5Z0ZcX2H5XUcTIzszGSRqXbnfRhSb+TjvVaSTPS/QAxlbaWO6tiP38s6R8k/Vs6vj9Ov/ZJWpJm59Syz37N3fnK4UvSJEkuaUbG9TPS6ycq+aP/YmCbvZJuTX/+QyUN56iK6/9Leh/XV2QuaWb68+j08ucrrh+VZlecYuxflrSp4vIiSTv7+t+UL746f0n6lKQ3Jf1KyR/ppZImddpmr6S1nbJdkv660za3Vlz+rqSWTrd5u7Yqsq2S/qZTdo2SCdQkjU1vN7ni+gsl/VrSolM8rqclreqUbZK0t+Lyfkl/22mbzZIeSn8el+77HCUvWx6VtFDS99PrPydpV8VtWyS9Kum0imxV5d8CvviK9VXiWr5e0uFq95Ne/oakzYHHfmvWfor6xZHO/GUdBrdO1z/Xzf38lqTX3H1/Rfaskpfpu/Nixc+vpd9HvD0Qszlm1mpm7ZassvuCpAuquF+gT7n7I0oWCfyRpH9W8kTs383stk6bvtjp8muqqIE6
XCbpr9KXwA+n9bNG0lBJ71fS+J2Q9EzFmF/RO3WYZZySibfS25fN7Cwlj/tfO23zlKTx6X7aJL2h5AjnZEk/kbRO0mQzG5Tmmzvd/sfufrzicl7/TsAplbiWa91PQ2EhUX52KWkof1vSPwWuP3kU4ifp5f/o5v5M2Q1sd946+YO7uyVvSB4gSWZ2raS/V7LC72lJByXNU3IkFuj33P1XkjamX0vM7H9KWmRmS939WLrZW51vpnzeTjRA0mJJ/xi4rl3vPLmMJfQ3oTLbouTl9HZJP3D3vWbWoeQlxd9X8tJ8pVj/TkC3GqSWu9tPQ6HpzIm7v2lm35f038xsmVe8rzN9Y/48Sf+cblfNXbZJGmVm57r7yWdWzaq/2K6QtM3dv1ExvovqvE+gL/1Yyd+y90g61s22tXhL0sBO2Q8l/Za77w7dwMzalNTo5Uqe1MnMLlByROdU2pS8d/v+iuxDJ39w94Nm9pqS+n2yYpsrlDz+kzYreS/oASVPLqWkEb1RXd/PCfQ3Zajlzk65n1M4pq5jLjyaznzNV/LLucnM/lrv/sgkS6+v1kYliwlWpyvzhki6W8n7POtZyfZ/JV1vZtOUfEzFZ5QcAflFHfcJRGdmv6HkaMH9Sl5yO6TkidgCSU+4+8Gcd7lX0sfMbIuko+7+CyVv7P+umb2iZLHCcSULcSa6+wJ3f9nMHpd0r5ndKOmXSur2l93s6+uSHjCzZ5U0hjOVvE/8zYpt/k7J0aBdSt6e86dKFgJdVrHNZiWLFkfrnQZzs5L3au7u9HYdoE+UvJY7O+V+uhnzh83soXTMHTXut1/iZZQcuftPlBTOjyQ9KGmPkvdutEm63N1/WsN9nVDykvfpSt5TslpJ8+pK3njdU/cq+cVfo+Q9oqMlfa2O+wN6y2FJ/65kReoWJXV2p5Lf5Wsj7O+LSl6qflXSdkly9+9L+mSaP5N+LZT0s4rbXa/ks3qfVHLihjVKJpBM7v6wkgV8d6T7+l0lE1yle5Q0nl+VtFPJ34dPufvzFffTJunnkl72dz6/8AdKjphsrupRA/GVtpY7q3I/IX8r6Xwlb8krzcvwJ1dOoQDM7IOSnpfU7O7dLUQCAADoN2g6+zEzm6FkwdEuJUck71byMv0E5z8OAAAUCO/p7N/OVPJhuOcrec/lZklfoOEEAABFw5FOAAAARMdCIgAAAERXV9NpZlPN7GUz221mC/MaFIA4qFmgOKhXlE2PX143s4FKPvNxipIT0z8raZa7/zjrNsOHD/fRo0f3aH9Anvbu3auOjo7YZ4/pV2qtWeoV/QX1yhyLYsmq2XoWEk1U8mHDeyTJzNZJulrvPjvGu4wePVqtra117BLIR3Nzc18PoS/UVLPUK/oL6pU5FsWSVbP1vLw+SskHrZ60L80A9E/ULFAc1CtKp56mM/RSR5fX6s3sRjNrNbPW9vbSfKg+UETd1iz1CvQbzLEonXqazn1KPj/ypPMkvdZ5I3df6e7N7t7c1NRUx+4A1KnbmqVegX6DORalU0/T+aykMWb2ATMbLOkzkh7LZ1gAIqBmgeKgXlE6PV5I5O7HzWy+pO9LGijpfnf/UW4jA5ArahYoDuoVZVTXaTDd/XuSvpfTWABERs0CxUG9omw4IxEAAACio+kEAABAdDSdAAAAiI6mEwAAANHRdAIAACA6mk4AAABER9MJAACA6Gg6AQAAEB1NJwAAAKKj6QQAAEB0NJ0AAACIjqYTAAAA0dF0AgAAIDqaTgAAAER3Wl8PALV76623gvkLL7wQzDdt2hTM161b1yXbsWNHcNuFCxcG85kzZwbzCRMmBHMAAOoVmgfzmAMl5sGYONIJAACA6Gg6AQAAEB1NJwAAAKKj6QQAAEB0NJ0AAACIrq7V62a2V9IhSb+WdNzdm/MYFE4ta8Xdww8/HMxPnDgRzM8777wumZkFt/3KV74SzO+5555gnrVacNKkScEcvYOaBYqDes0WmgfzmAOluPNgo8+BeXxk0pXu
3pHD/QDoHdQsUBzUK0qDl9cBAAAQXb1Np0v6FzN7zsxuzGNAAKKiZoHioF5RKvW+vD7Z3V8zsxGSNprZS+6+tXKDtFBulKQLLrigzt0BqNMpa5Z6BfoV5liUSl1HOt39tfT7AUmPSpoY2Galuze7e3NTU1M9uwNQp+5qlnoF+g/mWJRNj490mtlQSQPc/VD68x9IWpLbyKA9e/YE86NHjwbzrBV67h7Mly5d2iUbPHhwcNsbbrghmD/11FPBfM6cOcH86aefDuZDhgwJ5sgPNQsUB/WaqGUezGMOlOLOg40+B9bz8vpISY+mHy1wmqQ17v54LqMCEAM1CxQH9YrS6XHT6e57JH0wx7EAiIiaBYqDekUZ8ZFJAAAAiI6mEwAAANHRdAIAACC6PE6DiTqtWLEimC9YsCCYHzlypKb7z1q5t2RJ14WQa9asCW67atWqYD5u3LhgvmPHjmC+aNGiYJ51TlsAQPnFnAdrmQOluPNgo8+BHOkEAABAdDSdAAAAiI6mEwAAANHRdAIAACA6mk4AAABEx+r1XrR9+/ZgPnfu3GA+YED4OcH06dOD+axZs4L5jBkzgvmhQ4eCecjYsWODeda5brPGfvjw4ar3CZTBW2+9FcxfeOGFYL5p06Zgvm7dui5Z1qdELFy4MJjPnDkzmE+YMCGYA3mLOQ/GnAOlfObBRp8DOdIJAACA6Gg6AQAAEB1NJwAAAKKj6QQAAEB0NJ0AAACIjtXrkWzbtq1LNmXKlOC2Wavz7rrrrmB+0003BfMhQ4ZUObrE6aefXtP2IbNnzw7mDz30UDDv6OgI5sePHw/mp53GryiKLbTqXJIefvjhYJ61Eva8887rkplZcNus8zjfc889wTxrxfykSZOCOdCd0BwoxZ0H+2IOlGqbBxt9DuRIJwAAAKKj6QQAAEB0NJ0AAACIjqYTAAAA0XXbdJrZ/WZ2wMx2VmTvM7ONZrYr/T4s7jABVIuaBYqDekUjMXc/9QZmH5F0WNID7v47afZVSW+6+5fNbKGkYe7+l93trLm52VtbW3MYdv+RtUJv2rRpXbKDBw8Gt/3kJz8ZzLNWvda6Qi+mV155JZhfdNFFNd3Pm2++GczPOuusmsdUjebmZrW2toaX/hZcXjVbxnqNac+ePcH8ySefDOYbNmwI5ll/k5ctW9YlGzx4cHDbG264IZg/9dRTwfziiy8O5k8//XQw7+2/QdRr/51ja5kDJebBLL09B8aWVbPdHul0962SOv9rXC1pdfrzaknX1D1CALmgZoHioF7RSHr6ns6R7v66JKXfR+Q3JAARULNAcVCvKKXoC4nM7EYzazWz1vb29ti7A1AH6hUoFmoWRdLTpvMNMztHktLvB7I2dPeV7t7s7s1NTU093B2AOlVVs9Qr0C8wx6KUenp+pcckzZb05fT7d3IbUcGsXr06mGe9WTpk/fr1wTyvU3TFNGwYiyoLgprNyYoVK4L5ggULgvmRI0dquv+shURLlizpkq1Zsya47apVq4L5uHHjgvmOHTuC+aJFi4J51mk2kZvC1Gsec6DEPNgoqvnIpLWS/k3SfzazfWb2WSWFMMXMdkmakl4G0A9Qs0BxUK9oJN0e6XT3WRlXfSznsQDIATULFAf1ikbCGYkAAAAQHU0nAAAAoqPpBAAAQHQ9Xb3ecK688spgvnXr1mA+efLkLtnGjRuD2xZhdV6tuju9KtDfbd++PZjPnTs3mA8YEH4OP3369GA+a1b4rXwzZswI5ocOHQrmIWPHjg3mJ06cCOZZYz98+HDV+0S55TEHSsyDjY4jnQAAAIiOphMAAADR0XQCAAAgOppOAAAAREfTCQAAgOhYvd7Jq6++GsyzVuidffbZwXzp0qVdsjKuztuyZUswN7NgPmLEiGA+cODA3MYE1Grbtm1dsilTpgS3zVrpfddddwXzm266KZgPGTKkytEl8vj7MXv27GD+0EMPBfOOjo5gfvz48WB+2mlMKWUQmgfzmAMl5sFGnwM5
0gkAAIDoaDoBAAAQHU0nAAAAoqPpBAAAQHQ0nQAAAIiuYZcatrW1BfPm5uaa7mfVqlXBfOLEiTWPqYjmz59f0/ZZ55UeOnRoHsMBTim0Sl2Spk2b1iU7cuRIcNusc6nntUo9psWLFwfzrNXrjzzySDDP+rt31lln9Wxg6BN5zIONPgdKtc2DjT4HcqQTAAAA0dF0AgAAIDqaTgAAAERH0wkAAIDoum06zex+MztgZjsrskVmtt/Mnk+/PhF3mACqRc0CxUG9opFUs3q9RdI3JD3QKV/m7uGTqxbAkiVLgvnRo0eD+XXXXRfMr7rqqtzG1J9lnZN+3759wTzr3Ovz5s3LbUzI1KIS1mweVq9eHcwPHjxY9X2sX78+mBfhnNLDhg3r6yGgqxb1Ub3WMg82+hwo5TMPNvoc2O2RTnffKunNXhgLgBxQs0BxUK9oJPW8p3O+mb2YvjTA02eg/6NmgeKgXlE6PW06vyXpIkmXSHpd0teyNjSzG82s1cxa29vbe7g7AHWqqmapV6BfYI5FKfWo6XT3N9z91+5+QtIqSZmnHnD3le7e7O7NTU1NPR0ngDpUW7PUK9D3mGNRVj1qOs3snIqLMyTtzNoWQN+jZoHioF5RVt2uXjeztZI+Kmm4me2T9CVJHzWzSyS5pL2SPh9xjHVpaWkJ5lkrUGu9n7LZv39/ML/44ouDubsH82uvvTaYjx8/vmcDQ9WKXrN5uPLKK4P51q1bg/nkyZO7ZBs3bgxuW4RV6rXKqmPE1xv1msc82ChzoBR3Hmz0ObDbptPdZwXi+yKMBUAOqFmgOKhXNBLOSAQAAIDoaDoBAAAQHU0nAAAAoqPpBAAAQHTVnHu90J555plgnnVu8Kzzy5bR2rVru2QPPvhgcNtDhw4F80mTJgXzFStW9HxgQJWyzoWctUr97LPPDuZLl3Y9xXUZV6lv2bIlmGf9PRwxYkQwHzhwYG5jQnzMg2GhOVBiHoyJI50AAACIjqYTAAAA0dF0AgAAIDqaTgAAAERX+oVEtRo3blxfDyF3ixcvDuahNzm3t7cHt806/VfWqQLPOOOMKkcHdK+trS2YNzc313Q/q1atCuYTJ06seUxFNH/+/Jq2nzFjRjAfOnRoHsNBP1W2ebCWOVBiHoyJI50AAACIjqYTAAAA0dF0AgAAIDqaTgAAAERH0wkAAIDoSrN6/ejRo8F8586dvTyS+Do6OoL51KlTg/n27duDeegUaFmr8zZv3hzMWZ2H3rBkyZJgnlX3Wafxu+qqq3IbU3+WdXrQffv2BfOs0yHOmzcvtzEhPncP1kSjzIN5zIES82BMHOkEAABAdDSdAAAAiI6mEwAAANHRdAIAACA6mk4AAABE1+3qdTM7X9IDkt4v6YSkle7+dTN7n6SHJY2WtFfSp939F/GGemrHjh0L5lmr1tw9mD/++OPBfO7cucH8zDPPrGJ0p3bgwIFgfvvttwfz5cuX13T/F154YTC/8847u2TTp08PbsvqvGIoSr1maWlpCebr16/P5X7KZv/+/cE8a/Vt1t+9a6+9NpiPHz++ZwND1fKsWXcPzoV5zIMx50Ap7jxYyxwoMQ/GVM2RzuOSvuju4yR9SNI8MxsvaaGkJ9x9jKQn0ssA+hb1ChQLNYuG0W3T6e6vu/sP058PSWqTNErS1ZJWp5utlnRNrEECqA71ChQLNYtGUtN7Os1stKQJkrZJGunur0tJ0UgakXGbG82s1cxa29vb6xstgKpRr0Cx1FuzWScOAfqLqptOMztD0iOSbnH3g9Xezt1Xunuzuzc3NTX1ZIwAakS9AsWSR80OHz483gCBHFTVdJrZICXF8G1335DGb5jZOen150gKvwsYQK+iXoFioWbRKKpZvW6S7pPU5u53V1z1mKTZkr6cfv9OlBFWKWsFXdY5mFeuXBnMt27dGszPPffcmu6/Fvfee28wzzov7MiRI4P5pZdeGsyzVvLyrLh8ilKvWZ555plgnlULedRfUaxdu7ZL
9uCDDwa3PXToUDCfNGlSMF+xYkXPB4a65FmzAwYMCM6FecyDMedAKZ95kDmw/+u26ZQ0WdJ1knaY2fNpdpuSQlhvZp+V9DNJfxJniABqQL0CxULNomF023S6+1OSwk81pI/lOxwA9aBegWKhZtFIOCMRAAAAoqPpBAAAQHQ0nQAAAIiumoVEhXbLLbcE80cffTSYZ30g9pEjR4J51uq/PEydOjWYL1u2LJiPGTMm2liA/mjcuHF9PYTcLV68OJiHVphn/b3KOvf6xo0bgznnlC63PObBvpgDpdrmQebA/o8jnQAAAIiOphMAAADR0XQCAAAgOppOAAAAREfTCQAAgOhKv3p97NixwXz37t3B/KWXXgrmGzZsCOZr1qwJ5nPnzq1idIk5c+YE8yFDhgTzQYMGVX3fQH909OjRYL5z585eHkl8HR0dwTxrVe727duDeegc1Fmr1Ddv3hzMWaXemPKYB2POgRLzYKPgSCcAAACio+kEAABAdDSdAAAAiI6mEwAAANHRdAIAACC60q9ezzJ06NBgftlll9WU33HHHbmNCWgUx44dC+ZZK7fdPZg//vjjwTxr5eyZZ55ZxehO7cCBA8H89ttvD+bLly+v6f4vvPDCYH7nnXd2yaZPnx7cllXqqEYt8yBzIPLAkU4AAABER9MJAACA6Gg6AQAAEB1NJwAAAKLrdiGRmZ0v6QFJ75d0QtJKd/+6mS2S9DlJ7emmt7n792INFED3ilKvWQt6rrvuumC+cuXKYL5169Zgfu6559Z0/7W49957g3noNJWSNHLkyGB+6aWXBvOWlpZgPnz48O4Hh8IpSs0Ceahm9fpxSV909x+a2ZmSnjOzjel1y9x9abzhAagR9QoUCzWLhtFt0+nur0t6Pf35kJm1SRoVe2AAake9AsVCzaKR1PSeTjMbLWmCpG1pNN/MXjSz+81sWMZtbjSzVjNrbW9vD20CIALqFSgWahZlV3XTaWZnSHpE0i3uflDStyRdJOkSJc/Svha6nbuvdPdmd29uamrKYcgAukO9AsVCzaIRVNV0mtkgJcXwbXffIEnu/oa7/9rdT0haJWlivGECqBb1ChQLNYtGUc3qdZN0n6Q2d7+7Ij8nfS+KJM2QtDPOEAFUq+j1essttwTzRx99NJhnvZx45MiRYJ61Cj4PU6dODebLli0L5mPGjIk2FhRH0WsWqEU1q9cnS7pO0g4zez7NbpM0y8wukeSS9kr6fJQRAqgF9QoUCzWLhlHN6vWnJIU+gI7PCwP6GeoVKBZqFo2EMxIBAAAgOppOAAAAREfTCQAAgOiqWUgEAL1i7NixwXz37t3B/KWXXgrmGzZsCOZr1qwJ5nPnzq1idIk5c+YE8yFDhgTzQYMGVX3fAFBmHOkEAABAdDSdAAAAiI6mEwAAANHRdAIAACA6mk4AAABEZ+7eezsza5f0SnpxuKSOXtt53+Fx9k8XuntTXw+iP6NeS69Ij5V6rQI1W2pFe5zBmu3VpvNdOzZrdffmPtl5L+Jxogwa5f+3UR6n1FiPtRE1yv8vj7NYeHkdAAAA0dF0AgAAILq+bDpX9uG+exOPE2XQKP+/jfI4pcZ6rI2oUf5/eZwF0mfv6QQAAEDj4OV1AAAARNfrTaeZTTWzl81st5kt7O39x2Rm95vZATPbWZG9z8w2mtmu9PuwvhxjHszsfDP7gZm1mdmPzOzmNC/dY0V5a5Z6Ld9jRXnrVWqMmi17vfZq02lmAyV9U9I0SeMlzTKz8b05hshaJE3tlC2U9IS7j5H0RHq56I5L+qK7j5P0IUnz0v/HMj7Whlbymm0R9Vq2x9rQSl6vUmPUbKnrtbePdE6UtNvd97j7MUnrJF3dy2OIxt23SnqzU3y1pNXpz6slXdOrg4rA3V939x+mPx+S1CZplEr4WFHemqVey/dYUd56lRqjZster73ddI6S9GrF5X1pVmYj3f11KfllkjSij8eTKzMbLWmCpG0q+WNtUI1Ws6X+HaZeS6/R
6lUq8e9xGeu1t5tOC2Qsny8oMztD0iOSbnH3g309HkRBzZYE9doQqNeSKGu99nbTuU/S+RWXz5P0Wi+Pobe9YWbnSFL6/UAfjycXZjZISUF82903pHEpH2uDa7SaLeXvMPXaMBqtXqUS/h6XuV57u+l8VtIYM/uAmQ2W9BlJj/XyGHrbY5Jmpz/PlvSdPhxLLszMJN0nqc3d7664qnSPFQ1Xs6X7HaZeG0qj1atUst/jstdrr384vJl9QtLfSxoo6X53v6NXBxCRma2V9FFJwyW9IelLkv5J0npJF0j6maQ/cffOb4QuFDO7QtL/kbRD0ok0vk3J+05K9VhR3pqlXqnXMiprvUqNUbNlr1fOSAQAAIDoOCMRAAAAoqPpBAAAQHQ0nQAAAIiOphMAAADR0XQCAAAgOppOAAAAREfTCQAAgOhoOgEAABDd/wccbMWsmyuSlAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 864x216 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Visual sanity check of shift_image on one training digit.\n",
    "sample_idx = 1000\n",
    "image = X_train[sample_idx]\n",
    "shifted_image_down = shift_image(image, 0, 4)\n",
    "shifted_image_left = shift_image(image, -4, 0)\n",
    "\n",
    "panels = [\n",
    "    (\"Original\", image),\n",
    "    (\"Shifted down\", shifted_image_down),\n",
    "    (\"Shifted left\", shifted_image_left),\n",
    "]\n",
    "fig, axes = plt.subplots(1, len(panels), figsize=(12, 3))\n",
    "for ax, (title, pixels) in zip(axes, panels):\n",
    "    ax.set_title(title, fontsize=14)\n",
    "    ax.imshow(pixels.reshape(28, 28), interpolation=\"nearest\", cmap=\"Greys\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def augment_with_shifts(images, labels, shifts=((1, 0), (-1, 0), (0, 1), (0, -1))):\n",
    "    \"\"\"Return (images, labels) extended with a shifted copy of every image per (dx, dy) in `shifts`.\n",
    "\n",
    "    The result holds the originals first, then every image shifted by shifts[0],\n",
    "    then by shifts[1], and so on; each copy keeps its source image's label.\n",
    "    Both outputs are NumPy arrays built with np.array.\n",
    "    \"\"\"\n",
    "    augmented_images = list(images)\n",
    "    augmented_labels = list(labels)\n",
    "    for dx, dy in shifts:\n",
    "        augmented_images.extend(shift_image(img, dx, dy) for img in images)\n",
    "        augmented_labels.extend(labels)\n",
    "    return np.array(augmented_images), np.array(augmented_labels)\n",
    "\n",
    "\n",
    "# One-pixel shifts right, left, down and up: 4 extra copies, i.e. 5x the training data.\n",
    "X_train_augmented, y_train_augmented = augment_with_shifts(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mix the originals and their shifted copies, which are otherwise stored in blocks.\n",
    "# A dedicated, fixed-seed RandomState makes the order reproducible across runs.\n",
    "shuffle_idx = np.random.RandomState(42).permutation(len(X_train_augmented))\n",
    "X_train_augmented = X_train_augmented[shuffle_idx]\n",
    "y_train_augmented = y_train_augmented[shuffle_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# KNN classifier: 4 neighbours, votes weighted by inverse distance,\n",
    "# using all available CPU cores (n_jobs=-1).\n",
    "knn_clf = KNeighborsClassifier(\n",
    "    n_neighbors=4,\n",
    "    weights='distance',\n",
    "    n_jobs=-1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this cell re-creates the same unfitted classifier as the previous\n",
    "# cell (identical settings); it is redundant and can be deleted.\n",
    "knn_clf = KNeighborsClassifier(n_neighbors=4,weights='distance',n_jobs=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
       "           metric_params=None, n_jobs=-1, n_neighbors=4, p=2,\n",
       "           weights='distance')"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit on the augmented training set (original images plus their shifted copies).\n",
    "knn_clf.fit(X=X_train_augmented, y=y_train_augmented)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict a label for every held-out test image.\n",
    "y_pred = knn_clf.predict(X=X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9763"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "# Fraction of test images whose predicted label matches the true label.\n",
    "accuracy_score(y_true=y_test, y_pred=y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
